package com.ml4ai.junit;

import com.ml4ai.nn.core.Toolkit;
import com.ml4ai.nn.core.Variable;
import org.junit.Test;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Smoke test for back-propagation through a join of two tensors.
 *
 * <p>Builds {@code loss = sum(connect(first, second, 0)^2)} from two random inputs, runs the
 * backward pass from {@code loss}, and prints each input's values followed by its gradient for
 * manual inspection.
 *
 * <p>NOTE(review): this test contains no assertions, so it only fails if an exception is thrown.
 * If {@code connect(..., 0)} concatenates along axis 0, the gradient of {@code loss} with respect
 * to each input should be twice that input — consider asserting that once the
 * {@code Variable.data}/{@code Variable.grad} API is confirmed.
 */
public class Backward {

    @Test
    public void test() {
        // The two inputs differ only in their first dimension (5 vs 7), which is the axis
        // passed to connect() below — presumably a concatenation along axis 0; TODO confirm.
        Variable first = new Variable(Nd4j.rand(new int[]{5, 3, 4}));
        Variable second = new Variable(Nd4j.rand(new int[]{7, 3, 4}));

        // Scalar objective: sum of squares of the joined tensor.
        Variable loss = first.connect(second, 0).square().sum();

        Toolkit toolkit = new Toolkit();
        // Presumably resets the gradients reachable from `loss` before back-propagating;
        // verify against Toolkit.
        toolkit.grad2zero(loss);
        toolkit.backward(loss);

        System.out.println("开始记录"); // "start of log"
        printDataAndGrad(first);
        System.out.println("++++"); // separator between the two inputs
        printDataAndGrad(second);
    }

    /**
     * Prints a variable's values, then its gradient, each preceded by a header line.
     *
     * @param v the variable to dump; its gradient is read after the backward pass has run
     */
    private static void printDataAndGrad(Variable v) {
        System.out.println("原始数据"); // "original data"
        System.out.println(v.data.tensor);
        System.out.println("梯度"); // "gradient"
        System.out.println(v.grad.tensor);
    }

}
